{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read the data file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- Names: string (nullable = true)\n",
      " |-- Age: double (nullable = true)\n",
      " |-- Total_Purchase: double (nullable = true)\n",
      " |-- Account_Manager: integer (nullable = true)\n",
      " |-- Years: double (nullable = true)\n",
      " |-- Num_Sites: double (nullable = true)\n",
      " |-- Onboard_date: timestamp (nullable = true)\n",
      " |-- Location: string (nullable = true)\n",
      " |-- Company: string (nullable = true)\n",
      " |-- Churn: integer (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "\n",
    "spark = SparkSession.builder.appName('churn_log_reg').getOrCreate()\n",
    "df = spark.read.csv('customer_churn.csv', inferSchema = True, header = True)\n",
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Names',\n",
       " 'Age',\n",
       " 'Total_Purchase',\n",
       " 'Account_Manager',\n",
       " 'Years',\n",
       " 'Num_Sites',\n",
       " 'Onboard_date',\n",
       " 'Location',\n",
       " 'Company',\n",
       " 'Churn']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Summary Statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>summary</th>\n",
       "      <td>count</td>\n",
       "      <td>mean</td>\n",
       "      <td>stddev</td>\n",
       "      <td>min</td>\n",
       "      <td>max</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Names</th>\n",
       "      <td>900</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>Aaron King</td>\n",
       "      <td>Zachary Walsh</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Age</th>\n",
       "      <td>900</td>\n",
       "      <td>41.81666666666667</td>\n",
       "      <td>6.127560416916251</td>\n",
       "      <td>22.0</td>\n",
       "      <td>65.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Total_Purchase</th>\n",
       "      <td>900</td>\n",
       "      <td>10062.82403333334</td>\n",
       "      <td>2408.644531858096</td>\n",
       "      <td>100.0</td>\n",
       "      <td>18026.01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Account_Manager</th>\n",
       "      <td>900</td>\n",
       "      <td>0.4811111111111111</td>\n",
       "      <td>0.4999208935073339</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Years</th>\n",
       "      <td>900</td>\n",
       "      <td>5.27315555555555</td>\n",
       "      <td>1.274449013194616</td>\n",
       "      <td>1.0</td>\n",
       "      <td>9.15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Num_Sites</th>\n",
       "      <td>900</td>\n",
       "      <td>8.587777777777777</td>\n",
       "      <td>1.7648355920350969</td>\n",
       "      <td>3.0</td>\n",
       "      <td>14.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Location</th>\n",
       "      <td>900</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>00103 Jeffrey Crest Apt. 205 Padillaville, IA ...</td>\n",
       "      <td>Unit 9800 Box 2878 DPO AA 75157</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Company</th>\n",
       "      <td>900</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>Abbott-Thompson</td>\n",
       "      <td>Zuniga, Clark and Shaffer</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Churn</th>\n",
       "      <td>900</td>\n",
       "      <td>0.16666666666666666</td>\n",
       "      <td>0.3728852122772358</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     0                    1                   2  \\\n",
       "summary          count                 mean              stddev   \n",
       "Names              900                 None                None   \n",
       "Age                900    41.81666666666667   6.127560416916251   \n",
       "Total_Purchase     900    10062.82403333334   2408.644531858096   \n",
       "Account_Manager    900   0.4811111111111111  0.4999208935073339   \n",
       "Years              900     5.27315555555555   1.274449013194616   \n",
       "Num_Sites          900    8.587777777777777  1.7648355920350969   \n",
       "Location           900                 None                None   \n",
       "Company            900                 None                None   \n",
       "Churn              900  0.16666666666666666  0.3728852122772358   \n",
       "\n",
       "                                                                 3  \\\n",
       "summary                                                        min   \n",
       "Names                                                   Aaron King   \n",
       "Age                                                           22.0   \n",
       "Total_Purchase                                               100.0   \n",
       "Account_Manager                                                  0   \n",
       "Years                                                          1.0   \n",
       "Num_Sites                                                      3.0   \n",
       "Location         00103 Jeffrey Crest Apt. 205 Padillaville, IA ...   \n",
       "Company                                            Abbott-Thompson   \n",
       "Churn                                                            0   \n",
       "\n",
       "                                               4  \n",
       "summary                                      max  \n",
       "Names                              Zachary Walsh  \n",
       "Age                                         65.0  \n",
       "Total_Purchase                          18026.01  \n",
       "Account_Manager                                1  \n",
       "Years                                       9.15  \n",
       "Num_Sites                                   14.0  \n",
       "Location         Unit 9800 Box 2878 DPO AA 75157  \n",
       "Company                Zuniga, Clark and Shaffer  \n",
       "Churn                                          1  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.describe().toPandas().transpose()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build the Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import VectorAssembler\n",
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.classification import LogisticRegression\n",
    "\n",
    "assembler = VectorAssembler(inputCols = ['Age', 'Total_Purchase', 'Account_Manager', 'Years', 'Num_Sites',], outputCol = 'features')\n",
    "log_reg = LogisticRegression(featuresCol = 'features', labelCol = 'Churn', maxIter=10)\n",
    "pipeline = Pipeline(stages = [assembler, log_reg])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, test = df.randomSplit([0.7, 0.3])\n",
    "lrModel = pipeline.fit(train)\n",
    "predictions = lrModel.transform(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- Names: string (nullable = true)\n",
      " |-- Age: double (nullable = true)\n",
      " |-- Total_Purchase: double (nullable = true)\n",
      " |-- Account_Manager: integer (nullable = true)\n",
      " |-- Years: double (nullable = true)\n",
      " |-- Num_Sites: double (nullable = true)\n",
      " |-- Onboard_date: timestamp (nullable = true)\n",
      " |-- Location: string (nullable = true)\n",
      " |-- Company: string (nullable = true)\n",
      " |-- Churn: integer (nullable = true)\n",
      " |-- features: vector (nullable = true)\n",
      " |-- rawPrediction: vector (nullable = true)\n",
      " |-- probability: vector (nullable = true)\n",
      " |-- prediction: double (nullable = false)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predictions.printSchema()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Make Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------+--------------------+--------------------+\n",
      "|Churn|prediction|         probability|       rawPrediction|\n",
      "+-----+----------+--------------------+--------------------+\n",
      "|    0|       0.0|[0.92746093189776...|[2.54832538455355...|\n",
      "|    0|       0.0|[0.99571902458018...|[5.44928422849747...|\n",
      "|    0|       0.0|[0.87982491526622...|[1.99077320888602...|\n",
      "|    0|       0.0|[0.97001076460616...|[3.47646867190283...|\n",
      "+-----+----------+--------------------+--------------------+\n",
      "only showing top 4 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predictions.select('Churn', 'prediction', 'probability', 'rawPrediction').show(4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Performance Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8076314401266057"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "\n",
    "eval = BinaryClassificationEvaluator(labelCol = 'Churn', rawPredictionCol = 'rawPrediction')\n",
    "eval.evaluate(predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict on the new data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- Names: string (nullable = true)\n",
      " |-- Age: double (nullable = true)\n",
      " |-- Total_Purchase: double (nullable = true)\n",
      " |-- Account_Manager: integer (nullable = true)\n",
      " |-- Years: double (nullable = true)\n",
      " |-- Num_Sites: double (nullable = true)\n",
      " |-- Onboard_date: timestamp (nullable = true)\n",
      " |-- Location: string (nullable = true)\n",
      " |-- Company: string (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "new_data = spark.read.csv('new_customers.csv', inferSchema=True, header=True)\n",
    "new_data.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "lrModel_new = pipeline.fit(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions_new = lrModel_new.transform(new_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- Names: string (nullable = true)\n",
      " |-- Age: double (nullable = true)\n",
      " |-- Total_Purchase: double (nullable = true)\n",
      " |-- Account_Manager: integer (nullable = true)\n",
      " |-- Years: double (nullable = true)\n",
      " |-- Num_Sites: double (nullable = true)\n",
      " |-- Onboard_date: timestamp (nullable = true)\n",
      " |-- Location: string (nullable = true)\n",
      " |-- Company: string (nullable = true)\n",
      " |-- features: vector (nullable = true)\n",
      " |-- rawPrediction: vector (nullable = true)\n",
      " |-- probability: vector (nullable = true)\n",
      " |-- prediction: double (nullable = false)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predictions_new.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------+----------------+----------+\n",
      "|         Names|         Company|prediction|\n",
      "+--------------+----------------+----------+\n",
      "| Andrew Mccall|        King Ltd|       0.0|\n",
      "|Michele Wright|   Cannon-Benson|       1.0|\n",
      "|  Jeremy Chang|Barron-Robertson|       1.0|\n",
      "|Megan Ferguson|   Sexton-Golden|       1.0|\n",
      "|  Taylor Young|        Wood LLC|       0.0|\n",
      "| Jessica Drake|   Parks-Robbins|       1.0|\n",
      "+--------------+----------------+----------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predictions_new.select('Names', 'Company', 'prediction').show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
